from unidecode import unidecode


def repair_edf_header(input_edf_path, output_edf_path):
    """Repair common header defects of an EDF file and write a corrected copy.

    The whole input file is read into memory, the header fields below are
    rewritten in place, and the result is written to ``output_edf_path``.
    The input file is closed before the output file is opened.

    Repaired fields:

    * patient ID (bytes 8-88): decoded as GB2312 (undecodable bytes are
      dropped), transliterated to ASCII with ``unidecode``, then truncated
      or space-padded to exactly 80 bytes; an empty result becomes
      ``'UNKNOWN'``;
    * every signal label (``repair_label``);
    * every physical dimension (``repair_physical_dimension``);
    * every physical minimum/maximum pair that is equal
      (``repair_physical_val``).

    Args:
        input_edf_path: Path of the EDF file to read.
        output_edf_path: Path the repaired file is written to.

    Raises:
        ValueError: If the file is too short to contain the header fields
            that are repaired, or if the signal-count field (or a physical
            minimum/maximum field) does not hold a number.
    """
    with open(input_edf_path, 'rb') as f:
        # Read the entire EDF file into memory
        edf_data = bytearray(f.read())

    # The fixed part of an EDF header is 256 bytes; without it the signal
    # count cannot be read at all.
    if len(edf_data) < 256:
        raise ValueError(
            f'EDF header is truncated: expected at least 256 bytes, got {len(edf_data)}'
        )

    # Parse the number of signals.
    # According to the EDF specification, bytes 252-256 of the file header hold
    # the number of signals (bytes 236-244 hold the number of data records).
    num_signals = int(edf_data[252:256].decode('ascii').strip())

    # Per-signal fields follow the fixed header, each stored as one array
    # covering all signals: labels (16 bytes each), transducer types (80),
    # physical dimensions (8), physical minimums (8), physical maximums (8).
    label_position = 256
    physical_dimension_position = 256 + num_signals * (16 + 80)
    physical_position = 256 + num_signals * (16 + 80 + 8)

    # Every slice touched below must lie inside the file; a slice assignment
    # past the end of the bytearray would silently append bytes instead.
    repaired_fields_end = physical_position + num_signals * (8 + 8)
    if len(edf_data) < repaired_fields_end:
        raise ValueError(
            f'EDF header is truncated: {num_signals} signals need at least '
            f'{repaired_fields_end} bytes, got {len(edf_data)}'
        )

    # Decode the patient ID field and transliterate it to plain ASCII
    patient_id_str = unidecode(edf_data[8:88].decode('gb2312', errors='ignore'))

    # Fall back to 'UNKNOWN' only after transliteration, so that an ID which
    # decodes to something but transliterates to nothing is covered as well.
    if patient_id_str == '':
        patient_id_str = 'UNKNOWN'

    # Truncate/pad to exactly 80 characters and write the field back
    edf_data[8:88] = patient_id_str[:80].ljust(80).encode('ascii')

    # Iterate through each signal
    for i in range(num_signals):
        # Handle garbled characters in channel name and units
        repair_label(edf_data, i, label_position)
        repair_physical_dimension(edf_data, i, physical_dimension_position)
        # Handle cases where physical values are equal
        repair_physical_val(edf_data, i, physical_position, num_signals)

    # Write the contents from memory into the new EDF file
    with open(output_edf_path, 'wb') as f:
        f.write(edf_data)


def repair_label(edf_data, i, label_position):
    """Rewrite the 16-byte label of signal ``i`` as plain ASCII, in place.

    The field is decoded as GB2312 with undecodable bytes dropped. If the
    decoded text contains any non-ASCII character it is transliterated with
    ``unidecode`` and all spaces are removed; otherwise it is kept as is.
    An empty result is replaced by ``CHANNEL<n>`` (``n`` is 1-based). The
    label is finally space-padded/truncated to 16 bytes and stored back.

    Args:
        edf_data: Mutable byte buffer holding the EDF file; modified in place.
        i: Zero-based index of the signal.
        label_position: Offset of the first signal label in ``edf_data``.
    """
    field_width = 16
    start = label_position + field_width * i
    stop = start + field_width

    # Undecodable bytes are silently discarded by errors='ignore'
    decoded = edf_data[start:stop].decode('gb2312', errors='ignore')

    if all(ord(ch) < 128 for ch in decoded):
        # Already pure ASCII: keep the text untouched
        label = decoded
    else:
        # Transliterate and squeeze out spaces to keep the result short
        label = unidecode(decoded).replace(" ", '')

    if not label:
        label = f'CHANNEL{i + 1}'

    # Pad to the field width, then cut anything that overflows it
    edf_data[start:stop] = label.ljust(field_width).encode('ascii')[:field_width]


def _format_edf_number(value, width=8):
    """Render ``value`` as text that is exactly ``width`` characters wide.

    ``str(value)`` is used when it fits, so ordinary values keep their usual
    spelling. Otherwise the number of significant digits is reduced until
    the text fits; the result is left-aligned and space-padded.

    Raises:
        ValueError: If ``value`` cannot be represented in ``width`` characters.
    """
    text = str(value)
    precision = width
    while len(text) > width and precision > 0:
        text = f'{value:.{precision}g}'
        precision -= 1
    if len(text) > width:
        raise ValueError(f'{value!r} does not fit in {width} characters')
    return text.ljust(width)


def repair_physical_val(edf_data, i, physical_position, num_signals):
    """Fix signal ``i`` when its physical minimum equals its physical maximum.

    When both values are equal, the maximum becomes ``maximum + 1`` and the
    minimum becomes the negated new maximum (for example ``0``/``0`` turns
    into ``-1.0``/``1.0``). If that rule would still yield a degenerate or
    infinite range, ``-1.0``/``1.0`` is used instead. Pairs that already
    differ are left untouched.

    Args:
        edf_data: Mutable byte buffer holding the EDF file; modified in place.
        i: Zero-based index of the signal.
        physical_position: Offset of the first physical-minimum field. The
            physical-maximum fields are expected ``num_signals * 8`` bytes
            after it.
        num_signals: Number of signals in the file.

    Raises:
        ValueError: If either field does not contain a number.
    """
    physical_minimum_start = physical_position + i * 8
    physical_minimum_end = physical_minimum_start + 8

    physical_maximum_start = physical_position + num_signals * 8 + i * 8
    physical_maximum_end = physical_maximum_start + 8

    physical_minimum = float(edf_data[physical_minimum_start:physical_minimum_end].decode('ascii').strip())
    physical_maximum = float(edf_data[physical_maximum_start:physical_maximum_end].decode('ascii').strip())

    if physical_minimum != physical_maximum:
        return

    physical_maximum = physical_maximum + 1
    # maximum + 1 == 0 (original value -1) would give -0.0 / 0.0, which is
    # still an empty range; an infinite value cannot be widened either.
    if physical_maximum == 0 or abs(physical_maximum) == float('inf'):
        physical_maximum = 1.0
    physical_minimum = -physical_maximum

    # Each field must stay exactly 8 bytes: assigning a longer value to the
    # slice would grow the bytearray and shift every byte that follows.
    edf_data[physical_minimum_start:physical_minimum_end] = _format_edf_number(physical_minimum).encode('ascii')
    edf_data[physical_maximum_start:physical_maximum_end] = _format_edf_number(physical_maximum).encode('ascii')


def repair_physical_dimension(edf_data, i, physical_dimension_position):
    """Rewrite the 8-byte physical dimension (unit) of signal ``i`` as ASCII.

    Known encodings of the micro prefix are replaced by the letter ``u``
    first, so that a unit such as microvolt becomes ``uV`` instead of losing
    its prefix and turning into ``V``. All remaining non-ASCII bytes are
    dropped and the text is space-padded/truncated to 8 bytes.

    Args:
        edf_data: Mutable byte buffer holding the EDF file; modified in place.
        i: Zero-based index of the signal.
        physical_dimension_position: Offset of the first physical-dimension
            field in ``edf_data``.
    """
    physical_dimension_start = physical_dimension_position + i * 8
    physical_dimension_end = physical_dimension_start + 8

    # Read the unit label
    physical_dimension_bytes = bytes(edf_data[physical_dimension_start:physical_dimension_end])

    # Multi-byte spellings go first so the single Latin-1 byte cannot split them.
    micro_prefix_encodings = (
        b'\xc2\xb5',  # U+00B5 MICRO SIGN, UTF-8
        b'\xce\xbc',  # U+03BC GREEK SMALL LETTER MU, UTF-8
        b'\xa6\xcc',  # U+03BC GREEK SMALL LETTER MU, GB2312
        b'\xb5',      # U+00B5 MICRO SIGN, Latin-1
    )
    for micro_prefix in micro_prefix_encodings:
        physical_dimension_bytes = physical_dimension_bytes.replace(micro_prefix, b'u')

    # errors='ignore' already discards every remaining non-ASCII byte
    physical_dimension_str = physical_dimension_bytes.decode('ascii', errors='ignore').ljust(8)

    # Convert the string back to bytes and replace the unit label in memory
    edf_data[physical_dimension_start:physical_dimension_end] = physical_dimension_str.encode('ascii')[:8]
